import jax 
import functools as ftools
import jax.numpy as np

@jax.jit
def _concatenate(x, y):
    if np.ndim(x) == 0: x = np.array([x])
    if np.ndim(y) == 0: y = np.array([y])
    return np.concatenate([x, y], axis = 0)

@jax.jit
def _concatenate_trees(*trees):
    """Apply ``_concatenate`` leaf by leaf across the given pytrees.

    The trees are passed together to ``jax.tree_util.tree_map``, which calls
    ``_concatenate`` with the corresponding leaves of each tree and returns
    a tree with the structure of the first one. ``_concatenate`` takes two
    arguments, so in practice exactly two trees are expected.
    """
    # jax.tree_multimap was deprecated and later removed from JAX;
    # jax.tree_util.tree_map accepts multiple trees and replaces it.
    return jax.tree_util.tree_map(_concatenate, *trees)

@jax.jit
def _stack(x, y):
    if np.ndim(x) == 0: x = np.array([x])
    if np.ndim(y) == 0: y = np.array([y])
    if x.shape == y.shape:
        return np.stack([x, y], axis = 0)
    elif np.ndim(x) == (np.ndim(y)+1):
            return np.concatenate([x, np.expand_dims(y, 0)], axis = 0)
    else:
        print(x.shape)
        print(y.shape)
        raise ValueError
@jax.jit
def _stack_trees(*trees):
    """Apply ``_stack`` leaf by leaf across the given pytrees.

    The trees are passed together to ``jax.tree_util.tree_map``, which calls
    ``_stack`` with the corresponding leaves of each tree and returns a tree
    with the structure of the first one. ``_stack`` takes two arguments, so
    in practice exactly two trees are expected.
    """
    # jax.tree_multimap was deprecated and later removed from JAX;
    # jax.tree_util.tree_map accepts multiple trees and replaces it.
    return jax.tree_util.tree_map(_stack, *trees)

@jax.jit
def _append_results(rez):
    """Collapse each value of ``rez`` into one tree using ``_stack_trees``.

    Every value is treated as a sequence and folded left to right with
    ``functools.reduce(_stack_trees, ...)``. A one-element sequence is
    therefore returned as that element unchanged, and an empty sequence
    makes ``reduce`` raise ``TypeError`` since no initial value is given.

    Returns:
        A new dict with the same keys as ``rez`` and the folded values.
    """
    return {
        key: ftools.reduce(_stack_trees, entries)
        for key, entries in rez.items()
    }
